{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c8d01073-09a8-4c68-b93b-aef463738bd0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os, sys, glob, argparse\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, KFold\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "\n",
    "# Check if GPU is available\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "70d2263d-916c-431e-aca7-3e5bcd867f9b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "train_path = glob.glob('./Stable Diffusion鉴别器挑战赛公开数据/train/*/*')\n",
    "test_path = glob.glob('./Stable Diffusion鉴别器挑战赛公开数据/test/*')\n",
    "\n",
    "np.random.shuffle(train_path)\n",
    "test_path.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "584f2188-cf45-4d72-abee-4dae0d362926",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiDataset(Dataset):\n",
    "    def __init__(self, img_path, transform=None):\n",
    "        self.img_path = img_path\n",
    "        if transform is not None:\n",
    "            self.transform = transform\n",
    "        else:\n",
    "            self.transform = None\n",
    "    def __getitem__(self, index):\n",
    "        img = cv2.imread(self.img_path[index])\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(image = img)['image']\n",
    "        img = img.transpose([2,0,1])\n",
    "        return img, torch.from_numpy(np.array(int('real' in self.img_path[index])))\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "bd5a22b6-3a32-4039-bcc7-a6c6f5f9cbaf",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class XunFeiNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(XunFeiNet, self).__init__()\n",
    "        model = models.resnet34(True)\n",
    "        model.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        model.fc = nn.Linear(512, 2)\n",
    "        self.resnet = model\n",
    "    def forward(self, img):\n",
    "        out = self.resnet(img)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "47539bef-14a7-4881-85e9-0882ff295e6a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer):\n",
    "    model.train()\n",
    "    train_loss = 0.0\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.to(device)\n",
    "        target = target.to(device)\n",
    "\n",
    "        # compute output\n",
    "        output = model(input)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # compute gradient and do SGD step\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            print('Train loss', loss.item())\n",
    "            \n",
    "        train_loss += loss.item()\n",
    "    \n",
    "    return train_loss/len(train_loader)\n",
    "            \n",
    "def validate(val_loader, model, criterion):\n",
    "    model.eval()\n",
    "    \n",
    "    val_acc = 0.0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for i, (input, target) in enumerate(val_loader):\n",
    "            input = input.to(device)\n",
    "            target = target.to(device)\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            loss = criterion(output, target)\n",
    "            \n",
    "            val_acc += (output.argmax(1) == target).sum().item()\n",
    "            \n",
    "    return val_acc / len(val_loader.dataset)\n",
    "\n",
    "def predict(test_loader, model, criterion):\n",
    "    model.eval()\n",
    "    val_acc = 0.0\n",
    "    \n",
    "    test_pred = []\n",
    "    with torch.no_grad():\n",
    "        for i, (input, target) in enumerate(test_loader):\n",
    "            input = input.to(device)\n",
    "            target = target.to(device)\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            test_pred.append(output.data.cpu().numpy())\n",
    "            \n",
    "    return np.vstack(test_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "c14b410f-de94-4f52-be5b-1377c34bda4a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import albumentations as A\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_path[:-1000],\n",
    "            A.Compose([\n",
    "            A.RandomRotate90(),\n",
    "            A.Resize(256, 256),\n",
    "            A.RandomCrop(224, 224),\n",
    "            A.HorizontalFlip(p=0.5),\n",
    "            A.RandomContrast(p=0.5),\n",
    "            A.RandomBrightnessContrast(p=0.5),\n",
    "            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
    "        ])\n",
    "    ), batch_size=30, shuffle=True, num_workers=1, pin_memory=False\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(train_path[-1000:],\n",
    "            A.Compose([\n",
    "            A.Resize(256, 256),\n",
    "            A.RandomCrop(224, 224),\n",
    "            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
    "        ])\n",
    "    ), batch_size=30, shuffle=False, num_workers=1, pin_memory=False\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XunFeiDataset(test_path,\n",
    "            A.Compose([\n",
    "            A.Resize(256, 256),\n",
    "            A.RandomCrop(224, 224),\n",
    "            A.HorizontalFlip(p=0.5),\n",
    "            A.RandomContrast(p=0.5),\n",
    "            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))\n",
    "        ])\n",
    "    ), batch_size=2, shuffle=False, num_workers=1, pin_memory=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "7f7ee8cb-cfc0-4eb8-96ea-7302144ae7a4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model = XunFeiNet()\n",
    "model = model.to(device)\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.AdamW(model.parameters(), 0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b2faad78-e794-4f31-8b56-1a64dbf89ccc",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss 0.7959139347076416\n",
      "Train loss 0.29768866300582886\n",
      "Train loss 0.2745967209339142\n",
      "0.3281211804474394 0.9055555555555556 0.914\n",
      "Train loss 0.3272189795970917\n",
      "Train loss 0.14214976131916046\n",
      "Train loss 0.2875927686691284\n",
      "0.2151400999724865 0.9335555555555556 0.943\n"
     ]
    }
   ],
   "source": [
    "for _  in range(2):\n",
    "    train_loss = train(train_loader, model, criterion, optimizer)\n",
    "    val_acc  = validate(val_loader, model, criterion)\n",
    "    train_acc = validate(train_loader, model, criterion)\n",
    "    \n",
    "    print(train_loss, train_acc, val_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "6787d12a-b923-4777-abf5-4d90904551d1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "pred = None\n",
    "\n",
    "for _ in range(3):\n",
    "    if pred is None:\n",
    "        pred = predict(test_loader, model, criterion)\n",
    "    else:\n",
    "        pred += predict(test_loader, model, criterion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "1743990f-d27d-43cd-85fd-23a0ff0f2c36",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "submit = pd.DataFrame(\n",
    "    {\n",
    "        'uuid': [x.split('/')[-1] for x in test_path],\n",
    "        'label': pred.argmax(1)\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "cc565971-7a5f-4bb0-8ab5-c33281bc469c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "submit['label'] = submit['label'].map({1: 'real', 0: 'ai'})\n",
    "submit.to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2022bdd-5cee-4966-921a-aab21dd0ef45",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "state": {},
    "version_major": 2,
    "version_minor": 0
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
